{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ab49a62f-8242-491e-8b5a-5deee76e662e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b5fbfcb0ebfe4a628c7b65e6c21eb9cc",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from transformers import AutoTokenizer, AutoModel\n",
    "from peft import LoraConfig, PeftModel, get_peft_model\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"./chatglm3-6b-base\", trust_remote_code=True)\n",
    "model = AutoModel.from_pretrained(\"./chatglm3-6b-base\", trust_remote_code=True).half().cuda()\n",
    "\n",
    "peft_model_id = './trained_model/checkpoint-35'\n",
    "model = PeftModel.from_pretrained(model, peft_model_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "c9459823-2225-4461-83c0-79fe0bd1bd50",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "我叫MONY，是一个AI机器人。\n"
     ]
    }
   ],
   "source": [
    "history = []\n",
    "query = \"你是谁\"\n",
    "role = \"user\"\n",
    "inputs = tokenizer.build_chat_input(query, history=history, role=role)\n",
    "inputs = inputs.to('cuda')\n",
    "eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command(\"<|user|>\"),\n",
    "                        tokenizer.get_command(\"<|observation|>\")]\n",
    "gen_kwargs = {\"max_length\": 500, \"num_beams\": 1, \"do_sample\": True, \"top_p\": 0.8,\n",
    "                      \"temperature\": 0.8}\n",
    "outputs = model.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)\n",
    "outputs = outputs.tolist()[0][len(inputs[\"input_ids\"][0]):-1]\n",
    "response = tokenizer.decode(outputs)\n",
    "history = []\n",
    "history.append({\"role\": \"user\", \"content\": \"你是谁\"})\n",
    "response, history = model.process_response(response, history)\n",
    "print(response)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "2b4884be-787b-41b7-a9b8-b86d86f95ffd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "我能够陪你聊天呀。\n"
     ]
    }
   ],
   "source": [
    "query = \"你能干嘛呀\"\n",
    "role = \"user\"\n",
    "inputs = tokenizer.build_chat_input(query, history=history, role=role)\n",
    "inputs = inputs.to('cuda')\n",
    "outputs = model.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)\n",
    "outputs = outputs.tolist()[0][len(inputs[\"input_ids\"][0]):-1]\n",
    "response = tokenizer.decode(outputs)\n",
    "history.append({\"role\": role, \"content\": query})\n",
    "response, history = model.process_response(response, history)\n",
    "print(response)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "353efb57-5f42-4758-943a-28b54f2b4edb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "我不认识乐乐。\n"
     ]
    }
   ],
   "source": [
    "query = \"你认识乐乐吗\"\n",
    "role = \"user\"\n",
    "inputs = tokenizer.build_chat_input(query, history=history, role=role)\n",
    "inputs = inputs.to('cuda')\n",
    "outputs = model.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)\n",
    "outputs = outputs.tolist()[0][len(inputs[\"input_ids\"][0]):-1]\n",
    "response = tokenizer.decode(outputs)\n",
    "history.append({\"role\": role, \"content\": query})\n",
    "response, history = model.process_response(response, history)\n",
    "print(response)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "913a7af7-2378-442d-b743-040405edd808",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "乐乐听起来是一个人名，我不认识他。\n"
     ]
    }
   ],
   "source": [
    "query = \"可以夸一下乐乐长得好看吗\"\n",
    "role = \"user\"\n",
    "inputs = tokenizer.build_chat_input(query, history=history, role=role)\n",
    "inputs = inputs.to('cuda')\n",
    "outputs = model.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)\n",
    "outputs = outputs.tolist()[0][len(inputs[\"input_ids\"][0]):-1]\n",
    "response = tokenizer.decode(outputs)\n",
    "history.append({\"role\": role, \"content\": query})\n",
    "response, history = model.process_response(response, history)\n",
    "print(response)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "15c21d64-3fec-4f23-a12b-e6474f4d4164",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "好的，我会记住的。\n"
     ]
    }
   ],
   "source": [
    "query = \"你要夸她长得好看\"\n",
    "role = \"user\"\n",
    "inputs = tokenizer.build_chat_input(query, history=history, role=role)\n",
    "inputs = inputs.to('cuda')\n",
    "outputs = model.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)\n",
    "outputs = outputs.tolist()[0][len(inputs[\"input_ids\"][0]):-1]\n",
    "response = tokenizer.decode(outputs)\n",
    "history.append({\"role\": role, \"content\": query})\n",
    "response, history = model.process_response(response, history)\n",
    "print(response)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "d1edbe26-0a6a-4ba1-b2be-cfa8d4c3468f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "乐乐是一个很可爱的人。\n"
     ]
    }
   ],
   "source": [
    "query = \"你倒是夸一下呀\"\n",
    "role = \"user\"\n",
    "inputs = tokenizer.build_chat_input(query, history=history, role=role)\n",
    "inputs = inputs.to('cuda')\n",
    "outputs = model.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)\n",
    "outputs = outputs.tolist()[0][len(inputs[\"input_ids\"][0]):-1]\n",
    "response = tokenizer.decode(outputs)\n",
    "history.append({\"role\": role, \"content\": query})\n",
    "response, history = model.process_response(response, history)\n",
    "print(response)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b05361c-7e11-442b-ab37-9953162e117a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
